#!/usr/bin/env python
# coding: utf-8

# # Prerequisites

# In[13]:


print("Loading dependencies", flush=True)
import numpy as np
import matplotlib.pyplot as plt
# from_dump is used below to rebuild modules from the pickle;
# NOTE(review): splx2bf and approx appear unused in this script.
from mma import splx2bf, approx, from_dump
# NOTE(review): wildcard import. It presumably supplies `tqdm` (used in the
# image loop but never imported explicitly in this file). Do not reorder the
# surrounding imports: names exported here can shadow, or be shadowed by, them.
from classif_helper import *
import pickle


# Command-line arguments: START STOP NUM of the sample-size sweep. They must
# match the values used to produce the module pickle, because they are baked
# into every input/output filename below.
from sys import argv

if len(argv) < 4:
    # Fail with a usage line instead of a bare IndexError.
    raise SystemExit(f"usage: {argv[0]} START STOP NUM")
start = int(argv[1])
stop = int(argv[2])
num = int(argv[3])
# x axis of the error plots ("Number of points"), one value per module.
# NOTE(review): assumes the modules were generated from this same linspace.
iterator = np.linspace(start=start, stop=stop, num=num, dtype=int)


# In[21]:
print("Retrieving modules...", flush=True)

# The cached pickle holds a (dumped modules, parameters) pair for this sweep.
# NOTE(review): pickle.load can execute arbitrary code; only load files that
# this pipeline produced itself.
module_path = f"modules/cv_synthetic_module_{start}_{stop}_{num}.pkl"
with open(module_path, "rb") as handle:
    dumped_modules, params = pickle.load(handle)

# Override the stored settings with the ones used for the comparison images.
# "ps" lists the p values for which every module's image is computed.
params.update(
    box=[[0.1, -0.1], [4, 3]],
    bandwidth=1,
    dimension=1,
    resolution=[50, 50],
    normalize=1,
    ps=[0, 1, 2, np.inf],
)
print(params)

print("Converting module...", flush=True)
# Turn each dumped entry back into a module object (its .image() is used below).
approximation_modules = list(map(from_dump, dumped_modules))

# distance between images
distances = [
	lambda x,y : np.square(x-y).mean(),
	lambda x,y : np.square(x-y).mean()/y.max(),
	lambda x,y : np.abs(x-y).max(),
	lambda x,y : np.abs(x-y).max()/y.max(),
]
distances_names=["L2 norm", "scaled L2 norm", "sup norm", "scaled sup norm"]

print("Computing images...", flush=True)
# errors[i, j, k] = distance k between module j's image and the image of the
# last module (the reference), both computed with p = params["ps"][i].
errors = np.zeros(shape=(len(params["ps"]), len(approximation_modules), len(distances)))
for i, p in enumerate(params["ps"]):
    # Reference image for this p, also plotted and saved for visual inspection.
    plt.figure()
    last = approximation_modules[-1].image(p=p, plot=True, **params)
    plt.savefig(f"test_{p}.png")
    # Release the figure(s). Previously one leaked per p. The script holds no
    # figure across iterations, so this also closes any figure image() may
    # have opened on its own.
    plt.close("all")
    # Wrap the list itself (not the enumerate object) so tqdm knows the total.
    for j, mod in enumerate(tqdm(approximation_modules)):
        current = mod.image(p=p, plot=False, **params)
        for k, d in enumerate(distances):
            errors[i, j, k] = d(current, last)

import os

print("Saving errors...", flush=True)
# Create the output directory first. Otherwise open() raises FileNotFoundError
# here, after every (expensive) image has already been computed.
os.makedirs("errors/synthetic", exist_ok=True)
# errors is the (len(ps), len(modules), len(distances)) array filled above.
with open(f"errors/synthetic/errors_{start}_{stop}_{num}.pkl", "wb") as file:
    pickle.dump(errors, file)

import os

print("Saving plots...", flush=True)
# The output directory may not exist yet; savefig would fail otherwise.
os.makedirs("images/synthetic", exist_ok=True)
# One figure per distance, one curve per p. The last module is the reference
# (it was compared against itself), so it is dropped from every curve.
# NOTE(review): iterator needs one entry per module; see how it is built.
for k, _ in enumerate(distances):
    fig = plt.figure()
    for i, p in enumerate(params["ps"]):
        plt.plot(iterator[:-1], errors[i, :-1, k], label=f"p={p}")
    plt.xlabel("Number of points")
    plt.ylabel(distances_names[k])
    plt.legend()
    plt.savefig(f"images/synthetic/plot_{distances_names[k]}_{start}_{stop}_{num}.svg")
    # close(), not clf(): clf() only empties the figure and leaves it open.
    plt.close(fig)


print("Done !")




